# -*- coding: utf-8 -*-
"""
Created on Tue Dec 21 09:19:12 2021

@author: wulong
"""
import numpy as np
from casadi import *

from Real_system import *
from Long_term_model import *
from Slow_model import *
from Medium_model import *
from Fast_model import *

from Long_term_formulate_Opt import *
from Slow_formulate_Opt import *
from Medium_formulate_Opt import *
from Fast_formulate_Opt import *

from Long_term_MPC_solver import *
from Slow_MPC_solver import *
from Meduim_MPC_solver import *
from Fast_MPC_solver import *

from Weight_evaluation import *
from time import time

time_start = time()


# In[1] Common parameters
# half-width of the output zone built around the tbr setpoint
# (see output_setpnt_lo_r / output_setpnt_up_r below)
delta_zone = 0.5
# all input & state bounds of the real system:
# 6 continuous inputs, 4 binary inputs, 23 states
input_cont_bnds_lo_r = np.array([0.00055, 0.002, 1.9, 20, 1.3, 0])
input_cont_bnds_up_r = np.array([0.0045, 0.0066911, 3.5714, 110, 2.381, 1])
input_bin_bnds_lo_r = np.array([0, 0, 0, 0])
input_bin_bnds_up_r = np.array([1, 1, 1, 1])

state_bnds_lo_r = np.array([13, 0.03, 0.006, 0.095, 0.002, -75, -0.5, -4, -5, 
                            27, 25, 25, -5, 0, 7, 
                            -0.19, 0.1, -145, 2000, 2000, 0, 8, 18])
state_bnds_up_r = np.array([130, 0.4, 0.06, 1, 0.08, 5.00, 3, 1, 10, 
                            44, 38, 35, 20, 22, 22, 
                            0.19, 0.9, 155, 20000, 240000, 324000, 18, 26])
# normalized weight coefficients (Weight_MPC returns them as a sequence)
weights = Weight_MPC(input_cont_bnds_lo_r, input_cont_bnds_up_r, 
                     state_bnds_lo_r, state_bnds_up_r)
weight_y1 = weights[0]
weight_y2 = weights[1]
weight_u1_u2 = weights[2]
weight_u1 = weights[3]
weight_state_s = weights[4]
weight_state_m = weights[5]
weight_state_f = weights[6]
weight_input_s = weights[7]
weight_input_m = weights[8]
weight_input_f = weights[9]

weight_state_slack = np.array([2.5e3,2.5e3,4.44e-3,4.44e-3])

# Disturbance & setpoint sequences
# Load the long-term (1 hour) and real-time (10 minute) data.
# Column layout: [ta, ins, Pd, Qother, tbr]
# The files are separated by ", ". NumPy >= 1.23 only supports
# single-character delimiters in loadtxt, so split on "," and rely on the
# number parser ignoring the surrounding whitespace.
data_l_temp = np.loadtxt("data_long.txt", delimiter=',')
data_r_temp = np.loadtxt("data_real.txt", delimiter=',')

# Double data for prediction except initial point
data_l_temp_1 = np.tile(data_l_temp[1:,:], (2,1))
data_r_temp_1 = np.tile(data_r_temp[1:,:], (2,1))
# Detailed data for real time: each 10-minute row is held for
# 10*60/Delta_r real-system steps. np.repeat requires an integer count;
# the float produced by true division raises TypeError.
data_r_temp_2 = np.repeat(data_r_temp_1, int(10*60/Delta_r), axis = 0)
# real data including initial point
data_l = np.vstack((data_l_temp[0,:], data_l_temp_1))
data_r = np.vstack((data_r_temp[0,:], data_r_temp_2))

# real data calculation
# for disturbance [ta, ins, Pd, Qother]
distb_r = data_r[:,0:4]
# output for Pd
output_setpnt_r = data_r[:,2]
# output zone for tbr
output_setpnt_lo_r = data_r[:,-1] - delta_zone
output_setpnt_up_r = data_r[:,-1] + delta_zone

# In[2] L-MPC parameters, constraints and bounds
pred_horzn_l = 24

# objective weight and slack-penalty weights of the long-term problem
alpha_l = 1e3
alpha_slack_l = np.array([4, 1e2, 1e2])

# real-system steps / S-MPC steps contained in one L-MPC step
seed_l = int(Delta_l/Delta_r)
seed_l_s = int(Delta_l/Delta_s)

# buffer for the S-MPC reference interpolated over two L steps
# (filled in the main simulation loop)
opt_xk_s_l = np.zeros((2*seed_l_s + 1, 2))

# L-MPC states are x[16], x[18] and x[22] of the real system
state_bnds_lo_l = state_bnds_lo_r[[16, 18, 22]]
state_bnds_up_l = state_bnds_up_r[[16, 18, 22]]

# L-MPC continuous inputs: real inputs u[0], u[1], u[3], u[5] plus one extra
# input bounded to [-30, 30]. The variable lower bounds start at 0 for the
# real inputs; the *_lbg_l / *_ubg_l arrays carry the real-system ranges
# (presumably used as constraint bounds by L_MPC -- confirm there).
input_cont_bnds_lo_l = np.array([0, 0, 0, 0, -30])
input_cont_bnds_up_l = np.append(input_cont_bnds_up_r[[0, 1, 3, 5]], 30)
input_cont_lbg_l = np.append(input_cont_bnds_lo_r[[0, 1, 3, 5]], -30)
input_cont_ubg_l = np.append(input_cont_bnds_up_r[[0, 1, 3, 5]], 30)

# binary inputs share the real-system bound arrays (same objects)
input_bin_bnds_lo_l, input_bin_bnds_up_l = input_bin_bnds_lo_r, input_bin_bnds_up_r

# disturbances for L-MPC: data_l without column 2 (Pd) and column 4 (tbr),
# i.e. [ta, ins, Qother] for the [ta, ins, Pd, Qother, tbr] file layout
distb_bnds_l = np.delete(data_l, [2, 4], axis=1)
# output setpoint sequence for Pd (column 2)
output_setpnts_l = data_l[:, 2]
# zone of half-width delta_zone around tbr (last column)
output_setpnts_lo_l = data_l[:, -1] - delta_zone
output_setpnts_up_l = data_l[:, -1] + delta_zone


# In[3] S-MPC parameters, constraints and bounds
pred_horzn_s = 12

# Objective weights [y1, y2, u1-u2, slow-state 0, slow-state 1], scaled
# element-wise by [5, 5, 5, 8, 8]
alpha_s = np.array([weight_y1, weight_y2, weight_u1_u2, 
                    weight_state_s[0], weight_state_s[1]])*np.array([5,5,5,8,8])
# Diagonal slack-penalty matrix: weight_state_slack scaled by 0.1
alpha_slack_s = np.diag(weight_state_slack)*0.1

# Number of real-system steps per S-MPC step
seed_s = int(Delta_s/Delta_r)

# Slow-state bounds: real states x[16], x[18], x[19], x[20], x[22] (in that order)
state_bnds_lo_s = np.append(np.insert(state_bnds_lo_r[18:21], 0, 
                                      state_bnds_lo_r[16]), state_bnds_lo_r[22])
state_bnds_up_s = np.append(np.insert(state_bnds_up_r[18:21], 0, 
                                      state_bnds_up_r[16]), state_bnds_up_r[22])
# Medium-state bounds: real states x[3:9], x[11], x[14], x[21]
medium_state_bnds_lo_s = np.append(state_bnds_lo_r[3:9], 
                                   np.array([state_bnds_lo_r[11], state_bnds_lo_r[14], 
                                             state_bnds_lo_r[21]]))
medium_state_bnds_up_s = np.append(state_bnds_up_r[3:9], 
                                   np.array([state_bnds_up_r[11], state_bnds_up_r[14], 
                                             state_bnds_up_r[21]]))
# Fast-state bounds: every state not in the slow/medium sets, i.e. real
# states x[0], x[1], x[2], x[9], x[10], x[12], x[13], x[15], x[17]
fast_state_bnds_lo_s = np.delete(state_bnds_lo_r, 
                                 np.array((3,4,5,6,7,8,11,14,16,18,19,20,21,22)))
fast_state_bnds_up_s = np.delete(state_bnds_up_r, 
                                 np.array((3,4,5,6,7,8,11,14,16,18,19,20,21,22)))
# Positions 7 and 8 of the fast set are x[15] and x[17]: their hard bounds
# are removed here and their limits are passed separately as the
# fast_state_slack_* bounds below (presumably enforced softly via slacks --
# confirm in the *_formulate_Opt modules)
fast_state_bnds_lo_s[7:9] = np.array([-np.inf, -np.inf])
fast_state_bnds_up_s[7:9] = np.array([np.inf, np.inf])

# Bounds of x[15] and x[17], handled through slack variables
fast_state_slack_lo_s = np.array([state_bnds_lo_r[15], state_bnds_lo_r[17]])
fast_state_slack_up_s = np.array([state_bnds_up_r[15], state_bnds_up_r[17]])


# Slow inputs: u[2], u[4], u[5]; medium input: u[1]; fast inputs: u[0], u[3]
input_bnds_lo_s = np.insert(input_cont_bnds_lo_r[4:], 0, input_cont_bnds_lo_r[2])
input_bnds_up_s = np.insert(input_cont_bnds_up_r[4:], 0, input_cont_bnds_up_r[2])
medium_input_bnds_lo_s = input_cont_bnds_lo_r[1]
medium_input_bnds_up_s = input_cont_bnds_up_r[1]
fast_input_bnds_lo_s = np.array([input_cont_bnds_lo_r[0], input_cont_bnds_lo_r[3]])
fast_input_bnds_up_s = np.array([input_cont_bnds_up_r[0], input_cont_bnds_up_r[3]])

# S-MPC setpoint calculation: every seed_s-th sample of the real-time sequences
distb_bnds_s = distb_r[::seed_s, :]
output_setpnts_s = output_setpnt_r[::seed_s]
yspe_bnds_lo_s = output_setpnt_lo_r[::seed_s]
yspe_bnds_up_s = output_setpnt_up_r[::seed_s]


# In[4] M-MPC parameters, constraints and bounds
pred_horzn_m = 12
# real-system steps per M-MPC step
seed_m = int(Delta_m_temp/Delta_r)

# Objective weights: output/input-move terms scaled by [8, 1]; state and
# input weights scaled by 8; slack weights scaled by 0.1
alpha_m = np.array([weight_y1, weight_u1_u2])*np.array([8, 1])
alpha_xm_m = np.diag(weight_state_m)*8
alpha_xf_m = np.diag(weight_state_f)*8
alpha_um_m = weight_input_m*8
alpha_uf_m = np.diag(weight_input_f)*8
alpha_slack_m = np.diag(weight_state_slack)*0.1

# M-MPC reuses the medium / fast partitions built for S-MPC (same objects)
state_bnds_lo_m, state_bnds_up_m = medium_state_bnds_lo_s, medium_state_bnds_up_s
fast_state_bnds_lo_m, fast_state_bnds_up_m = fast_state_bnds_lo_s, fast_state_bnds_up_s
fast_state_slack_lo_m, fast_state_slack_up_m = fast_state_slack_lo_s, fast_state_slack_up_s

input_bnds_lo_m, input_bnds_up_m = medium_input_bnds_lo_s, medium_input_bnds_up_s
fast_input_bnds_lo_m, fast_input_bnds_up_m = fast_input_bnds_lo_s, fast_input_bnds_up_s

# M-MPC setpoints: every seed_m-th real-time sample
distb_bnds_m = distb_r[::seed_m, :]
output_setpnts_m = output_setpnt_r[::seed_m]


# In[5] F-MPC parameters, constraints and bounds
pred_horzn_f = 4
# real-system steps per F-MPC step
seed_f = int(Delta_f_temp/Delta_r)

# Objective weights: output/input terms scaled by [8, 1]; fast state and
# input weights scaled by 8; slack weights scaled by 0.1
alpha_f = np.array([weight_y1, weight_u1])*np.array([8, 1])
alpha_xf_f = np.diag(weight_state_f)*8
alpha_uf_f = np.diag(weight_input_f)*8
alpha_slack_f = np.diag(weight_state_slack)*0.1

# F-MPC reuses the fast partition built for S-MPC (same objects)
state_bnds_lo_f, state_bnds_up_f = fast_state_bnds_lo_s, fast_state_bnds_up_s
fast_state_slack_lo_f, fast_state_slack_up_f = fast_state_slack_lo_s, fast_state_slack_up_s
input_bnds_lo_f, input_bnds_up_f = fast_input_bnds_lo_s, fast_input_bnds_up_s

# F-MPC setpoints: every seed_f-th real-time sample
distb_bnds_f = distb_r[::seed_f, :]
output_setpnts_f = output_setpnt_r[::seed_f]


# In[6] Simulate things
# Simtime = 3600*24 (one day if time is in seconds), Nsim steps of Delta_r
Simtime = 3600*24
Nsim = int(Simtime/Delta_r)
# Number of L-MPC solves during the simulation: the main loop solves L-MPC
# whenever (i - 1) % seed_l == 0 for i in 1..Nsim, i.e. ceil(Nsim/seed_l)
# times. Every L-MPC record below is sized from this single count instead
# of hard-coded 24 / 25, so changing Simtime or Delta_l cannot overflow
# time_LMPC inside the loop.
Nsim_l = int(np.ceil(Nsim/seed_l))

T = np.reshape(np.arange(Nsim+1)*Delta_r, (Nsim+1, 1))
X = np.zeros((Nsim+1, Nx_r))
Y = np.zeros((Nsim+1, Ny_r))
U = np.zeros((Nsim+1, Nuc_r))
Z = np.zeros((Nsim+1, Nuz_r))


# recording setpoints
start_r = 0
Y_setpnt_r = output_setpnt_r[start_r : start_r + Nsim+1]
Y_zone_setpnt_r = np.zeros((Nsim+1, 2))
Y_zone_setpnt_r[:,0] = output_setpnt_lo_r[start_r : start_r + Nsim+1]
Y_zone_setpnt_r[:,1] = output_setpnt_up_r[start_r : start_r + Nsim+1]
D_r = distb_r[start_r : start_r + Nsim+1, :]


start_l = 0
U_l_solution = np.zeros((Nsim_l, Nuc_l))
X_l_solution = np.zeros((Nsim_l, Nx_l))
time_LMPC = np.zeros(Nsim_l)

# L-MPC setpoints: initial point followed by each L-step value held for
# seed_l real-system steps
Y_setpnt_l_temp = output_setpnts_l[start_l : start_l + Nsim_l + 1]
Y_setpnt_l_temp_1 = np.repeat(Y_setpnt_l_temp[1:], seed_l, axis = 0)
Y_setpnt_l = np.hstack((Y_setpnt_l_temp[0], Y_setpnt_l_temp_1))

Y_zone_setpnt_l_temp = np.zeros((Nsim_l + 1, 2))
Y_zone_setpnt_l_temp[:,0] = output_setpnts_lo_l[start_l : start_l + Nsim_l + 1]
Y_zone_setpnt_l_temp[:,1] = output_setpnts_up_l[start_l : start_l + Nsim_l + 1]
Y_zone_setpnt_l_temp_1 = np.repeat(Y_zone_setpnt_l_temp[1:,:], seed_l, axis = 0)
Y_zone_setpnt_l = np.vstack((Y_zone_setpnt_l_temp[0,:], Y_zone_setpnt_l_temp_1))

D_l_temp = distb_bnds_l[start_l : start_l + Nsim_l + 1, :]
D_l_temp_1 = np.repeat(D_l_temp[1:,:], seed_l, axis = 0)
D_l = np.vstack((D_l_temp[0,:], D_l_temp_1))


# initial point (23 states, 2 outputs, 6 continuous and 4 binary inputs)
initial_state = np.array([112.952, 0.28125, 0.052831, 0.800712, 0.066726, 0, 0, 0, 0, 
                          39.9997, 34.6358, 30.5001, 2.00017, 6.23489, 9.50001, 
                          0, 0.2, 0, 5000, 35000, 180000, 12, 22])

initial_output = np.array([93.5, 22])
initial_input_cont = np.array([0.0045, 0.0066911, 3.5714, 110, 2.381, 0])
initial_input_bin = np.array([1, 1, 1, 1])

# Record initial parameters
X[0,:] = initial_state
Y[0,:] = initial_output

# Preload last time instant parameters
X_int = X[0,:]
Y_int = Y[0,:]
U_int = initial_input_cont
Z_int = initial_input_bin
# binary vector handed to every L-MPC solve in the main loop
Z_int_0 = np.array([0, 1, 0, 1])

# In[7] Composite MPC simulation
# Each iteration i advances the real system by one Delta_r step. A layer is
# re-solved whenever (i - 1) is a multiple of its seed (L: seed_l, S: seed_s,
# M: seed_m, F: seed_f); otherwise its previous solution is held. At i = 1
# every layer solves, so all opt_* values exist before they are first used.

for i in range (1, Nsim+1):
    
    # Solve L-MPC
    if (i - 1) % seed_l == 0:
        time_start_l = time()
        # 1-based index of the current L-MPC solve
        i_l = int((i - 1)/seed_l + 1)
        
        # NOTE(review): the binary argument is always Z_int_0, never the
        # previously applied Z_int -- confirm this is intended.
        Opt_l = L_MPC(X_int, state_bnds_lo_l, state_bnds_up_l, 
                      U_int, input_cont_bnds_lo_l, input_cont_bnds_up_l, 
                      input_cont_lbg_l, input_cont_ubg_l, 
                      Z_int_0, input_bin_bnds_lo_l, input_bin_bnds_up_l, 
                      distb_bnds_l, 
                      Y_int, 
                      output_setpnts_l, output_setpnts_lo_l, output_setpnts_up_l, 
                      alpha_l, alpha_slack_l, pred_horzn_l, i_l)
       
        # Opt_l = (continuous inputs, binary inputs, states); the first-step
        # binary decision is applied to the plant and passed to lower layers
        opt_zk_l = Opt_l[1][0,:]
        
        # S-MPC reference: linearly interpolate the first two state columns
        # of the L-MPC trajectory from row 0 to row 1 ...
        opt_xk_s_l_temp = Opt_l[2][:2,:2]
        opt_xk_s_l_temp_1 = np.linspace(opt_xk_s_l_temp[0,0],
                                        opt_xk_s_l_temp[1,0], num = seed_l_s+1)
        opt_xk_s_l_temp_2 = np.linspace(opt_xk_s_l_temp[0,1],
                                        opt_xk_s_l_temp[1,1], num = seed_l_s+1)
        
        # ... and from row 1 to row 2
        opt_xk_s_l_temp_3 = Opt_l[2][1:3,:2]
        opt_xk_s_l_temp_4 = np.linspace(opt_xk_s_l_temp_3[0,0],
                                        opt_xk_s_l_temp_3[1,0], num = seed_l_s+1)
        opt_xk_s_l_temp_5 = np.linspace(opt_xk_s_l_temp_3[0,1],
                                        opt_xk_s_l_temp_3[1,1], num = seed_l_s+1)
        
        # join both segments, dropping the shared row-1 point -> 2*seed_l_s+1 rows
        opt_xk_s_l_temp_6 = np.concatenate((opt_xk_s_l_temp_1, np.delete(opt_xk_s_l_temp_4,0)))
        opt_xk_s_l_temp_7 = np.concatenate((opt_xk_s_l_temp_2, np.delete(opt_xk_s_l_temp_5,0)))
        
        opt_xk_s_l[:,0] = opt_xk_s_l_temp_6
        opt_xk_s_l[:,1] = opt_xk_s_l_temp_7
        
        # record the first-step input and the predicted next state; keep the
        # full trajectories of the latest solve for saving at the end
        U_l_solution[i_l - 1, :] = Opt_l[0][0,:]
        X_l_solution[i_l - 1, :] = Opt_l[2][1,:]
        U_l_last_solution = Opt_l[0]
        Z_l_last_solution = Opt_l[1]
        X_l_last_solution = Opt_l[2]
        
        time_end_l = time()
        print('L-MPC time: '+str(time_end_l - time_start_l))
        time_LMPC[i_l - 1] = time_end_l - time_start_l
        
        pass
    
    # Solve S-MPC
    if (i - 1) % seed_s == 0:
        time_start_s = time()
        # 1-based index of the current S-MPC solve
        i_s = int((i - 1)/seed_s + 1)
        # position of this S step within the current L-MPC step
        # (assumes seed_l == seed_l_s*seed_s -- TODO confirm)
        i_l_s = int(i_s - (i_l - 1)*seed_l_s)
        
        Opt_s = S_MPC(X_int, state_bnds_lo_s, state_bnds_up_s, 
                      medium_state_bnds_lo_s, medium_state_bnds_up_s, 
                      fast_state_bnds_lo_s, fast_state_bnds_up_s, 
                      fast_state_slack_lo_s, fast_state_slack_up_s, 
                      U_int, input_bnds_lo_s, input_bnds_up_s, 
                      medium_input_bnds_lo_s, medium_input_bnds_up_s, 
                      fast_input_bnds_lo_s, fast_input_bnds_up_s, 
                      opt_zk_l,
                      distb_bnds_s,
                      Y_int, 
                      yspe_bnds_lo_s, yspe_bnds_up_s,
                      output_setpnts_s, opt_xk_s_l, 
                      alpha_s, alpha_slack_s, pred_horzn_s, i_s, i_l_s)
        
        # first-step slow/medium/fast inputs and medium/fast states suggested
        # by S-MPC; they serve as references for M-MPC and F-MPC
        opt_uk_s = Opt_s[0][0,:]
        opt_uk_m_s = Opt_s[1][0,:]
        opt_uk_f_s = Opt_s[2][0,:]
        opt_xk_m_s = Opt_s[3][0,:]
        opt_xk_f_s = Opt_s[4][0,:]
        # slow-state subset of the current state: x[16], x[18:21], x[22]
        X_int_s = np.append(np.insert(X_int[18:21], 0, X_int[16]), X_int[22])
        
        time_end_s = time()        
        print('S-MPC time: '+str(time_end_s - time_start_s))
        pass
    
    # Solve M-MPC
    if (i - 1) % seed_m == 0:
        time_start_m = time()
        # 1-based index of the current M-MPC solve
        i_m = int((i - 1)/seed_m + 1)
        
        Opt_m = M_MPC(X_int, state_bnds_lo_m, state_bnds_up_m, 
                      X_int_s, 
                      fast_state_bnds_lo_m, fast_state_bnds_up_m, 
                      fast_state_slack_lo_m, fast_state_slack_up_m, 
                      opt_uk_s, 
                      U_int, input_bnds_lo_m, input_bnds_up_m, 
                      fast_input_bnds_lo_m, fast_input_bnds_up_m, 
                      opt_zk_l,
                      distb_bnds_m,
                      Y_int, 
                      output_setpnts_m, opt_xk_m_s, opt_xk_f_s, 
                      opt_uk_m_s, opt_uk_f_s, 
                      alpha_m, alpha_xm_m, alpha_xf_m, alpha_um_m, alpha_uf_m, 
                      alpha_slack_m, pred_horzn_m, i_m)
        
        # first-step medium input plus fast input/state references for F-MPC
        opt_uk_m = Opt_m[0][0,:]
        opt_uk_f_m = Opt_m[1][0,:]
        opt_xk_f_m = Opt_m[2][0,:]
        # medium-state subset of the current state: x[3:9], x[11], x[14], x[21]
        X_int_m = np.append(X_int[3:9], np.array([X_int[11], X_int[14], X_int[21]]))
        
        time_end_m = time()
        print('M-MPC time: '+str(time_end_m - time_start_m))
        pass
    
    # Solve F-MPC
    if (i - 1) % seed_f == 0:
        time_start_f = time()
        # 1-based index of the current F-MPC solve
        i_f = int((i - 1)/seed_f + 1)
        
        Opt_f = F_MPC(X_int, state_bnds_lo_f, state_bnds_up_f, 
                      fast_state_slack_lo_f, fast_state_slack_up_f, 
                      X_int_s, 
                      X_int_m, 
                      opt_uk_s, 
                      opt_uk_m, 
                      U_int, input_bnds_lo_f, input_bnds_up_f, 
                      opt_zk_l,
                      distb_bnds_f,
                      Y_int, 
                      output_setpnts_f, opt_xk_f_m, opt_uk_f_m, 
                      alpha_f, alpha_xf_f, alpha_uf_f, alpha_slack_f, pred_horzn_f, i_f)
        
        opt_uk_f = Opt_f[0][0,:]
        
        time_end_f = time()
        print('F-MPC time: '+str(time_end_f - time_start_f))
        pass
    
    # Combine real manipulated input & simulation, record, update
    # Real input order u[0..5] = [fast 0, medium 0, slow 0, fast 1, slow 1, slow 2],
    # matching the input partitions used for the layer bounds
    uk = np.array([opt_uk_f[0], opt_uk_m[0], opt_uk_s[0], 
                   opt_uk_f[1], opt_uk_s[1], opt_uk_s[2]])
    zk = opt_zk_l
    # row i of distb_r (row 0 is the initial point of the data)
    dk = distb_r[i,:]
    
    # advance the real system one step; I_ode_r is presumably a CasADi
    # integrator over Delta_r (called with x0/p, result read from 'xf')
    Ik = I_ode_r(x0 = X_int, p = vertcat(uk, zk, dk))
    xk = Ik['xf'].full().ravel()
    
    yk = out_ies_r(xk, uk, zk, dk).full().ravel()
    
    # inputs applied over [i-1, i] are stored at row i-1; the resulting
    # state/output at row i
    U[i-1,:] = uk
    Z[i-1,:] = zk
    X[i,:] = xk
    Y[i,:] = yk
    
    U_int = U[i-1,:]
    Z_int = Z[i-1,:]
    X_int = X[i,:]
    Y_int = Y[i,:]
    
    print('Simulation time: ' + str(i*Delta_r))
    
    pass

# Hold the last applied inputs at the final instant so that U and Z have a
# value in every row, like X and Y
U[-1] = U[-2]
Z[-1] = Z[-2]

time_end = time()
print(f'Composite MPC time: {time_end - time_start}')


# In[8] Save data
# Every result is written with the same numeric format and ", " separator,
# in the order listed below.
for out_name, out_data in (
        ("CMPC_T.txt", T),
        ("CMPC_U.txt", U),
        ("CMPC_Z.txt", Z),
        ("CMPC_X.txt", X),
        ("CMPC_Y.txt", Y),
        ("CMPC_Y_setpnt.txt", Y_setpnt_r),
        ("CMPC_Y_zone_setpnt.txt", Y_zone_setpnt_r),
        ("CMPC_Distb.txt", D_r),
        ("CMPC_long_U_solution.txt", U_l_solution),
        ("CMPC_long_X_solution.txt", X_l_solution),
        ("CMPC_long_X_last_solution.txt", X_l_last_solution),
        ("CMPC_long_U_last_solution.txt", U_l_last_solution),
        ("CMPC_long_Z_last_solution.txt", Z_l_last_solution),
        ("CMPC_long_Y_setpnt.txt", Y_setpnt_l),
        ("CMPC_long_Y_zone_setpnt.txt", Y_zone_setpnt_l),
        ("CMPC_long_Distb.txt", D_l),
        ("CMPC_long_time.txt", time_LMPC),
):
    np.savetxt(out_name, out_data, fmt='%.10e', delimiter = ', ')